#!/usr/bin/env python3
"""
unbalanced_ot_tv_laguerre_en_dist_asymTV.py — cost c(x, y) = ||x - y|| (ex. 5.1)

Semidiscrete unbalanced OT with asymmetric total variation regularization and weak OT.
Dual gradient descent implementation based on paper §3 (Theorem 3.5) using Laguerre tessellation.
Cost function changed to Euclidean distance (L2 norm) instead of squared distance.
Asymmetric TV regularization coefficients lam_plus and lam_minus replace symmetric lambda.
"""

import numpy as np
import matplotlib.pyplot as plt
import time

def _laguerre_labels(C, w):
    """Return, for each row n of C, the index i minimizing C[n, i] - w[i].

    This is the hard Laguerre assignment: sample n goes to the cell i with the
    smallest c(x_n, y_i) - w_i. Ties go to the lowest index (np.argmin).
    """
    return np.argmin(C - w[None, :], axis=1)


def _dual_ascent(C, nu, M_mu, lam_plus, lam_minus, alpha0, decay, T_max, tol_grad):
    """Iterate the box-projected dual update on the potentials w (starting at 0).

    At step t, with alpha_t = alpha0 / (1 + t)**decay:
        raw_grad = nu - mu(C_i)
        grad     = clip(raw_grad, -lam_minus, lam_plus)
        w       <- clip(w + alpha_t * grad, -lam_minus, lam_plus)
    where mu(C_i) = (M_mu / N) * (number of samples assigned to cell i).

    Parameters
    ----------
    C : ndarray, shape (N, m)
        Cost between each of the N samples of mu and each of the m Dirac points.
    nu : ndarray, shape (m,)
        Weights of the discrete measure nu.
    M_mu : float
        Total mass of mu; each sample carries M_mu / N.
    lam_plus, lam_minus : float
        Upper bound and (negated) lower bound used both to clip the gradient
        and to project w onto [-lam_minus, lam_plus].
    alpha0, decay : float
        Step-size schedule parameters.
    T_max : int
        Maximum number of iterations.
    tol_grad : float
        Stop early once the 2-norm of the clipped gradient drops below this.

    Returns
    -------
    w : ndarray, shape (m,)
        Final dual potentials.
    hist_grad : list of float
        2-norm of the clipped gradient at each iteration.
    hist_mass : list of float
        sum_i |nu_i - mu(C_i)| at each iteration.

    NOTE(review): the argmin assigns every sample to some cell, so mu(C_i) always
    sums to M_mu and raw_grad sums to sum(nu) - M_mu. When the two masses differ
    (the unbalanced case), raw_grad -- and therefore the clipped gradient -- can
    never reach zero. The tol_grad stop may then never fire, and the loop runs to
    T_max. Confirm against the paper whether the unbalanced dual should leave
    some mass unassigned (e.g. via a cost threshold) -- TODO.
    """
    N, m = C.shape
    w = np.zeros(m)
    hist_grad = []
    hist_mass = []
    for t in range(T_max):
        alpha_t = alpha0 / (1 + t) ** decay  # Decaying step size

        # 1. Laguerre assignment and estimated cell masses mu(C_i)
        labels = _laguerre_labels(C, w)
        mu_est = (M_mu / N) * np.bincount(labels, minlength=m)

        # 2. Gradient nu - mu(C_i), clipped to [-lam_minus, lam_plus] (asymmetric TV)
        raw_grad = nu - mu_est
        grad = np.clip(raw_grad, -lam_minus, lam_plus)
        # Step along +grad, then project w back onto the asymmetric box
        w = np.clip(w + alpha_t * grad, -lam_minus, lam_plus)

        # 3. Record convergence metrics
        gnorm = np.linalg.norm(grad)
        hist_grad.append(gnorm)
        hist_mass.append(np.abs(raw_grad).sum())

        if gnorm < tol_grad:
            print(f"Converged at iteration {t}, ||grad|| = {gnorm:.2e}")
            break
    return w, hist_grad, hist_mass


def _soft_assignment(points, Y, w, epsilon):
    """Return P[k, i] = softmax_i((w_i - ||points[k] - Y[i]||) / epsilon).

    The per-row maximum logit is subtracted before exponentiating. Softmax is
    invariant to this shift, and it prevents overflow (or 0/0) for small epsilon.
    """
    dist = np.linalg.norm(points[:, None, :] - Y[None, :, :], axis=2)
    logits = (w[None, :] - dist) / epsilon
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)


def _format_axes(xlabel, ylabel, fontsize=18):
    """Apply the axis labels, legend, and tick styling shared by every figure."""
    plt.xlabel(xlabel, fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
    plt.legend(loc='upper right', fontsize=fontsize)
    plt.tick_params(axis='both', which='major', labelsize=fontsize)


def main():
    """Run the asymmetric-TV semi-discrete unbalanced OT experiment and plot it.

    Steps:
    - Draw N uniform samples of mu on the unit square and m random Dirac points
      of nu, with weights rescaled to total mass M_nu.
    - Build the Euclidean cost c(x, y) = ||x - y||.
    - Run the projected dual update on the potentials w and print the solve time.
    - Show the figures: the Laguerre partition for the final w, the final
      potentials, the gradient-norm and mass-mismatch histories, and one
      soft-assignment heatmap per Dirac point.
    """
    # ========== Experiment Parameters ==========
    N = 10000           # Number of continuous measure μ samples
    m = 4               # Number of Dirac points for discrete measure ν
    M_mu = 1.0          # Total mass of μ
    M_nu = 0.7          # Total mass of ν (unbalanced)
    # Asymmetric TV regularization coefficients
    lam_plus = 1.0      # TV penalty for mass creation (ν_i - μ(C_i) > 0)
    lam_minus = 0.5     # TV penalty for mass destruction (ν_i - μ(C_i) < 0)
    alpha0 = 0.05       # Initial step size
    decay = 0.6         # Step size decay exponent: α_t = α0 / (1+t)^decay
    T_max = 1000        # Maximum number of iterations
    tol_grad = 1e-4     # Gradient norm tolerance
    rng = np.random.default_rng(1)

    # ========== Data Generation ==========
    X = rng.random((N, 2))          # Uniform samples of continuous measure
    Y = rng.random((m, 2))          # Random Dirac support points
    nu = rng.random(m)
    nu = nu / nu.sum() * M_nu       # Normalize weights to sum to M_nu

    # Precompute cost matrix C(x,y) = ||x - y||, shape (N, m)
    C = np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=2)

    # ========== Dual Optimization ==========
    start_time = time.perf_counter()
    w, hist_grad, hist_mass = _dual_ascent(
        C, nu, M_mu, lam_plus, lam_minus, alpha0, decay, T_max, tol_grad)
    end_time = time.perf_counter()
    # Print the optimization run time once ("运行时间" = run time, "秒" = seconds)
    print(f"运行时间: {end_time - start_time:.4f} 秒")

    # Laguerre cells for the final potentials, so that the partition plot agrees
    # with the plotted w and the heatmaps (and is defined even if T_max == 0)
    labels = _laguerre_labels(C, w)

    # ========== Visualization ==========
    # 1. Laguerre partition
    plt.figure(figsize=(6, 6))
    plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='tab10', s=1, alpha=0.5, label='Continuous samples')
    plt.scatter(Y[:, 0], Y[:, 1], c='k', marker='x', label='Dirac support', zorder=3, s=200, linewidths=3)
    _format_axes("$x_1$", "$x_2$")
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()

    # 2. Dual potentials distribution
    plt.figure(figsize=(6, 6))
    plt.bar(range(m), w, label='Values of dual potentials $w_i$')
    _format_axes("Discrete index $i$", "$w_i$")
    plt.grid(axis='y', linestyle='--', alpha=0.3)
    plt.tight_layout()

    # 3. Convergence of gradient norm
    plt.figure(figsize=(6, 6))
    plt.semilogy(hist_grad, label='Gradient norm $\\|\\nabla G\\|_2$')
    _format_axes("Iterations", "$\\|\\nabla G\\|_2$")
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()

    # 4. Convergence of total mass mismatch
    plt.figure(figsize=(6, 6))
    plt.plot(hist_mass, label='Mass mismatch $\\sum_i|\\mu(C_i)-\\nu_i|$')
    _format_axes("Iterations", "Mass mismatch")
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()

    # 5. Soft assignment probability heatmaps for each Dirac point
    epsilon = 0.1   # Soft assignment smoothing parameter
    grid_n = 200
    xx, yy = np.meshgrid(np.linspace(0, 1, grid_n), np.linspace(0, 1, grid_n))
    grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
    P_grid = _soft_assignment(grid, Y, w, epsilon)
    for i in range(m):
        plt.figure(figsize=(6, 6))
        prob_map = P_grid[:, i].reshape(grid_n, grid_n)
        img = plt.imshow(prob_map, extent=(0, 1, 0, 1), origin='lower', cmap='viridis')
        plt.scatter(Y[i, 0], Y[i, 1], c='red', marker='x', label=f'Discrete point $y_{i}$')
        _format_axes("x", "y")
        cbar = plt.colorbar(img)
        cbar.set_label(f'$P(y_{i}|x,y)$', fontsize=18)
        plt.tight_layout()

    plt.show()


if __name__ == '__main__':
    # Run the experiment only when executed as a script, never on import.
    main()